#pragma once

#include <windows.h>
#include <PathCch.h>

#include "catch.hpp"

#include <wil/filesystem.h>
#include <wil/result.h>
#include <roerrorapi.h>
#include <wil/com.h>

#include "mocking.h"

// Evaluates 'expr' exactly once (the use inside decltype is unevaluated) and yields 'true' if it reported an error.
// The first argument is a compile-time tag selecting the witest::ReportsError overload:
//      * 'expr' has type HRESULT: the overload that tests the returned code with FAILED()
//      * any other type:          the overload that looks for a thrown exception or a fail-fast
#define REPORTS_ERROR(expr) \
    witest::ReportsError(wistd::is_same<HRESULT, decltype(expr)>{}, [&]() { \
        return expr; \
    })
// Catch2 assertions layered on REPORTS_ERROR
#define REQUIRE_ERROR(expr) REQUIRE(REPORTS_ERROR(expr))
#define REQUIRE_NOERROR(expr) REQUIRE_FALSE(REPORTS_ERROR(expr))

// Evaluates 'expr' inside witest::DoesCodeCrash and yields 'true' if a structured exception was raised while doing so.
// C++ exceptions thrown by 'expr' are swallowed by DoesCodeCrash and are not considered a crash.
#define CRASHES(expr) \
    witest::DoesCodeCrash([&]() { \
        return expr; \
    })
// Catch2 assertions layered on CRASHES
#define REQUIRE_CRASH(expr) REQUIRE(CRASHES(expr))
#define REQUIRE_NOCRASH(expr) REQUIRE_FALSE(CRASHES(expr))

// NOTE: SUCCEEDED/FAILED macros not used here since Catch2 can give us better diagnostics if it knows the HRESULT value
#define REQUIRE_SUCCEEDED(expr) REQUIRE((HRESULT)(expr) >= 0)
#define REQUIRE_FAILED(expr) REQUIRE((HRESULT)(expr) < 0)

// MACRO double evaluation check.
// The following macro illustrates a common problem with writing macros:
//      #define MY_MAX(a, b) (((a) > (b)) ? (a) : (b))
// The issue is that whatever code is passed for 'a' or 'b' may be executed twice.
// This isn't harmful when thinking of constant numerics, but consider this example:
//      MY_MAX(4, InterlockedIncrement(&cCount))
// This evaluates the 'b' argument twice and results in incrementing the counter twice.
// We use MDEC in unit tests to verify that this kind of pattern is not present.  A test
// of this kind:
//      MY_MAX(MDEC(4), MDEC(InterlockedIncrement(&cCount)))
// will verify that the parameters are not evaluated more than once.
// MDEC(x) first records (__LINE__, "x") with witest::details::MacroDoubleEvaluationCheck and then, via the comma
// operator, yields the value of 'x' itself, so it can be dropped in anywhere 'x' could be used.
#define MDEC(PARAM) (witest::details::MacroDoubleEvaluationCheck(__LINE__, #PARAM), PARAM)

// There's some functionality that we need for testing that's not available for the app partition. Since those tests are
// primarily compilation tests, declare what's needed here
extern "C"
{

// Only declared here when none of the desktop/system/games partitions are being targeted.
// NOTE(review): presumably the SDK headers already provide these prototypes for those partitions - confirm
#if !WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP | WINAPI_PARTITION_SYSTEM | WINAPI_PARTITION_GAMES)
    WINBASEAPI _Ret_maybenull_
    PVOID WINAPI AddVectoredExceptionHandler(_In_ ULONG First, _In_ PVECTORED_EXCEPTION_HANDLER Handler);

    WINBASEAPI
    ULONG WINAPI RemoveVectoredExceptionHandler(_In_ PVOID Handle);
#endif
}

#pragma warning(push)
#pragma warning(disable : 4702) // Unreachable code

namespace witest
{
namespace details
{
    //! Backing routine for the MDEC() macro.
    //! Remembers the most recent (line, expression text) pairs it was called with and fails the current test when the
    //! same pair shows up again while it is still remembered, which indicates that a macro expanded (and evaluated)
    //! its argument more than once.
    //! @param uLine    Source line of the MDEC() usage
    //! @param pszCode  Stringized expression; the pointer itself is stored, so it must remain valid for later calls
    inline void MacroDoubleEvaluationCheck(size_t uLine, _In_ const char* pszCode)
    {
        struct Evaluation
        {
            size_t line;
            const char* code;
        };

        // Fixed-size history that wraps around; only the last ARRAYSIZE(s_history) evaluations are remembered
        static Evaluation s_history[15] = {};
        static size_t s_nextSlot = 0;

        for (size_t i = 0; i < ARRAYSIZE(s_history); ++i)
        {
            const Evaluation& seen = s_history[i];
            if (seen.code == nullptr)
            {
                continue; // slot has never been written
            }
            if (seen.line != uLine)
            {
                continue;
            }
            if (strcmp(pszCode, seen.code) != 0)
            {
                continue;
            }

            // This verification indicates that macro-double-evaluation check is firing for a particular usage of MDEC().
            FAIL("Expression '" << pszCode << "' double evaluated in macro on line " << uLine);
        }

        Evaluation& slot = s_history[s_nextSlot];
        slot.line = uLine;
        slot.code = pszCode;
        s_nextSlot = (s_nextSlot + 1) % ARRAYSIZE(s_history);
    }

    //! Scope guard returned by witest::AssignTemporaryValue.
    //! On construction it remembers the current value of `*pVal` and overwrites it with `val`; the remembered value is
    //! written back when the guard is destroyed or invoked, unless Dismiss() was called first. The guard is movable
    //! but not copyable.
    template <typename T>
    class AssignTemporaryValueCleanup
    {
    public:
        explicit AssignTemporaryValueCleanup(_Inout_ T* pVal, T val) WI_NOEXCEPT : m_pVal(pVal), m_valOld(*pVal)
        {
            *m_pVal = val;
        }

        AssignTemporaryValueCleanup(_In_ AssignTemporaryValueCleanup const&) = delete;
        AssignTemporaryValueCleanup& operator=(_In_ AssignTemporaryValueCleanup const&) = delete;

        //! Takes over responsibility for the restore; `other` is left dismissed
        AssignTemporaryValueCleanup(_Inout_ AssignTemporaryValueCleanup&& other) WI_NOEXCEPT : m_pVal(other.m_pVal),
                                                                                               m_valOld(other.m_valOld)
        {
            other.Dismiss();
        }

        ~AssignTemporaryValueCleanup() WI_NOEXCEPT
        {
            (*this)();
        }

        //! Restores the original value now. Subsequent calls (including the one made by the destructor) do nothing.
        void operator()() WI_NOEXCEPT
        {
            if (m_pVal == nullptr)
            {
                return;
            }

            T* const target = m_pVal;
            *target = m_valOld;
            m_pVal = nullptr;
        }

        //! Abandons the restore; the temporary value stays in place
        void Dismiss() WI_NOEXCEPT
        {
            m_pVal = nullptr;
        }

    private:
        T* m_pVal;  // Variable to restore; nullptr once restored, dismissed, or moved-from
        T m_valOld; // Value to write back
    };
} // namespace details

// Temporarily assigns a new value to a variable; the previous value is put back automatically when the returned
// object goes out of scope.
// Example:
//      int nFoo = 10;
//      {
//          auto revert = witest::AssignTemporaryValue(&nFoo, 12);
//          // nFoo is 12 for the remainder of this scope...
//      }
//      // ...and nFoo is 10 again out here
template <typename T>
inline witest::details::AssignTemporaryValueCleanup<T> AssignTemporaryValue(_Inout_ T* pVal, T val) WI_NOEXCEPT
{
    using Cleanup = witest::details::AssignTemporaryValueCleanup<T>;
    return Cleanup(pVal, val);
}

//! Counter type behind `witest::g_objectCount`, which tallies the live objects deriving from @ref AllocatedObject.
//! Call `witest::g_objectCount.Leaked()` to find out whether any such object is still alive.
class GlobalCount
{
public:
    //! Number of live objects; adjusted by @ref AllocatedObject
    int m_count = 0;

    //! Returns `true` if there are any objects that derive from @ref AllocatedObject still in memory.
    bool Leaked() const
    {
        return (0 != m_count);
    }

    ~GlobalCount()
    {
        if (!Leaked())
        {
            return;
        }

        // NOTE: This runs when no test is active, but will still cause an assert failure to notify
        FAIL_FAST_MSG("GlobalCount is non-zero; there is a leak somewhere");
    }
};
__declspec(selectany) GlobalCount g_objectCount;

//! Derive an allocated test object from witest::AllocatedObject to ensure that those objects aren't leaked in the test.
//! Note that you can call g_objectCount.Leaked() at any point to determine if a leak has already occurred (assuming that
//! all objects should have been destroyed at that point.
class AllocatedObject
{
public:
    AllocatedObject()
    {
        g_objectCount.m_count++;
    }
    ~AllocatedObject()
    {
        g_objectCount.m_count--;
    }
};

//! Invokes `callOp` and reports whether it exited by throwing a C++ exception (the exception is swallowed).
//! When WIL_ENABLE_EXCEPTIONS is not defined no handler is installed, so the result is always `false`.
template <typename Lambda>
bool DoesCodeThrow(Lambda&& callOp)
{
#ifdef WIL_ENABLE_EXCEPTIONS
    try
    {
        callOp();
    }
    catch (...)
    {
        return true;
    }
    return false;
#else
    callOp();
    return false;
#endif
}

//! Invokes `lambda` under a structured exception handler and returns `true` if a structured exception was raised.
//! The call is routed through DoesCodeThrow so that a C++ exception thrown by `lambda` is swallowed there and is not
//! reported as a crash.
template <typename TLambda>
bool DoesCodeCrash(TLambda&& lambda)
{
    __try
    {
        // for the purposes of this test, throwing an exception is not a crash
        DoesCodeThrow(wistd::forward<TLambda>(lambda));
    }
    __except (EXCEPTION_EXECUTE_HANDLER)
    {
        return true;
    }

    return false;
}

//! Invokes `callOp` with a detour installed over wil's fail-fast reporting routine and returns `true` if that detour
//! was hit while `callOp` ran. Installing the detour is itself asserted with REQUIRE_SUCCEEDED.
//! `detour` is a local, so it is destroyed when this function returns.
template <typename TLambda>
bool DoesCodeFailFast(TLambda&& callOp)
{
    bool failFast = false;

    // NOTE: 'ReportFailure_Base<wil::FailureType::FailFast, false>' is the entrypoint for failfasts, however this just
    // calls ReportFailure_NoReturn and under optimization, primarily on arm64, this can turn into just a jump
    // instruction. The end result is that the function is too "small" to detour. Detouring 'ReportFailure_NoReturn'
    // gives us effectively the same result
    witest::detoured_thread_function<&wil::details::ReportFailure_NoReturn<wil::FailureType::FailFast>> detour;
    // The replacement records that a fail-fast was requested and then forwards the same arguments to
    // 'ReportFailure_Base<wil::FailureType::FailFast, true>'.
    // NOTE(review): presumably the 'true' template argument makes wil report the failure without actually terminating
    // the process, which is what lets control return here - confirm against wil's result macros
    REQUIRE_SUCCEEDED(detour.reset(
        [&](__R_FN_PARAMS_FULL, const wil::details::ResultStatus& resultPair, PCWSTR message, wil::details::ReportFailureOptions options) {
            failFast = true;
            wil::details::ReportFailure_Base<wil::FailureType::FailFast, true>(__R_FN_CALL_FULL, resultPair, message, options);
        }));

    callOp();

    return failFast;
}

//! Overload selected by REPORTS_ERROR when the expression does not produce an HRESULT.
//! Such code is expected to signal failure either by fail-fasting or by throwing: returns `true` if DoesCodeFailFast
//! observed a fail-fast or, when WIL_ENABLE_EXCEPTIONS is defined, if an exception escaped `callOp`.
template <typename Lambda>
bool ReportsError(wistd::false_type, Lambda&& callOp)
{
#ifdef WIL_ENABLE_EXCEPTIONS
    try
    {
        return DoesCodeFailFast(wistd::forward<Lambda>(callOp));
    }
    catch (...)
    {
        return true;
    }
#else
    return DoesCodeFailFast(wistd::forward<Lambda>(callOp));
#endif
}

//! Overload selected by REPORTS_ERROR when the expression produces an HRESULT.
//! Invokes `callOp` once and returns `true` if the HRESULT it returns is a failure code.
template <typename Lambda>
bool ReportsError(wistd::true_type, Lambda&& callOp)
{
    const HRESULT hr = callOp();
    return FAILED(hr);
}

#ifdef WIL_ENABLE_EXCEPTIONS
//! Collects the failures that wil reports through the IFailureCallback interface so a test can inspect them.
//! The object hands itself to a ThreadFailureCallbackHolder on construction.
//! NOTE(review): presumably the holder registers the callback for the constructing thread only and unregisters it on
//! destruction - confirm against wil's ThreadFailureCallbackHolder
class TestFailureCache final : public wil::details::IFailureCallback
{
public:
    TestFailureCache() : m_callbackHolder(this)
    {
    }

    //! Discards every failure recorded so far
    void clear()
    {
        m_failures.clear();
    }

    //! Number of failures recorded so far
    size_t size() const
    {
        return m_failures.size();
    }

    //! `true` when no failure has been recorded
    bool empty() const
    {
        return m_failures.empty();
    }

    //! Returns the `pos`-th recorded failure, oldest first. Uses the vector's bounds-checked `at()`, so an
    //! out-of-range `pos` throws rather than reading past the end.
    const wil::FailureInfo& operator[](size_t pos) const
    {
        return m_failures.at(pos).GetFailureInfo();
    }

    // IFailureCallback
    // Appends a stored copy of 'failure' and always returns false.
    // NOTE(review): the meaning wil attaches to the return value is not visible here - confirm
    // NOLINTNEXTLINE(bugprone-exception-escape): Test code; let it crash
    bool NotifyFailure(wil::FailureInfo const& failure) WI_NOEXCEPT override
    {
        m_failures.emplace_back(failure);
        return false;
    }

private:
    // Declaration order matters: members are destroyed in reverse order, so the holder goes away before the vector
    std::vector<wil::StoredFailureInfo> m_failures;
    wil::details::ThreadFailureCallbackHolder m_callbackHolder;
};
#endif

//! Creates a uniquely named temporary file (prefix "wil") in the temporary directory reported by GetTempPathW and
//! writes its full path to `result`.
//! @param result  Receives the path of the file that GetTempFileNameW created
//! @return S_OK on success; the failing API's last error as an HRESULT otherwise, or
//!         HRESULT_FROM_WIN32(ERROR_INSUFFICIENT_BUFFER) if the temporary directory does not fit in MAX_PATH
//! NOTE(review): <windows.h> defines GetTempFileName as a macro, so the preprocessor presumably rewrites this
//! function's name (and every call to it) to GetTempFileNameW/A - confirm
inline HRESULT GetTempFileName(wchar_t (&result)[MAX_PATH])
{
    wchar_t dir[MAX_PATH];
    const DWORD dirLength = ::GetTempPathW(ARRAYSIZE(dir), dir);
    RETURN_LAST_ERROR_IF(dirLength == 0);
    // A return value that does not leave room for the terminator means the buffer was too small and 'dir' was not
    // filled in; bail out rather than handing an unwritten buffer to GetTempFileNameW
    RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_INSUFFICIENT_BUFFER), dirLength >= ARRAYSIZE(dir));
    RETURN_LAST_ERROR_IF(::GetTempFileNameW(dir, L"wil", 0, result) == 0);
    return S_OK;
}

//! Produces a path that is suitable for a new, uniquely named folder. Nothing exists at the returned path.
//! The name is obtained by letting GetTempFileNameW create a unique file (prefix "wil"), deleting that file again and
//! stripping the extension from its path.
//! @param buffer  Receives the resulting path
//! @param root    Directory to generate the name in; when null, the long form of the GetTempPathW directory is used
//! @return S_OK on success; otherwise the HRESULT of the step that failed, or
//!         HRESULT_FROM_WIN32(ERROR_INSUFFICIENT_BUFFER) if the temporary directory does not fit in MAX_PATH
inline HRESULT CreateUniqueFolderPath(wchar_t (&buffer)[MAX_PATH], PCWSTR root = nullptr)
{
    if (root)
    {
        RETURN_LAST_ERROR_IF(::GetTempFileNameW(root, L"wil", 0, buffer) == 0);
    }
    else
    {
        wchar_t tempPath[MAX_PATH];

        // For both calls below a return value that leaves no room for the terminator means "buffer too small", in
        // which case the buffer does not hold the requested path and must not be used
        const DWORD tempLength = ::GetTempPathW(ARRAYSIZE(tempPath), tempPath);
        RETURN_LAST_ERROR_IF(tempLength == 0);
        RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_INSUFFICIENT_BUFFER), tempLength >= ARRAYSIZE(tempPath));

        const DWORD longLength = ::GetLongPathNameW(tempPath, tempPath, ARRAYSIZE(tempPath));
        RETURN_LAST_ERROR_IF(longLength == 0);
        RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_INSUFFICIENT_BUFFER), longLength >= ARRAYSIZE(tempPath));

        RETURN_LAST_ERROR_IF(::GetTempFileNameW(tempPath, L"wil", 0, buffer) == 0);
    }

    // Only the unique name is wanted, not the file that GetTempFileNameW created to reserve it
    RETURN_IF_WIN32_BOOL_FALSE(DeleteFileW(buffer));
    // Propagate a failure instead of silently returning a path that still carries the extension
    RETURN_IF_FAILED(PathCchRemoveExtension(buffer, ARRAYSIZE(buffer)));
    return S_OK;
}

#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM) && (_WIN32_WINNT >= _WIN32_WINNT_WIN8)
//! Asserts that the calling thread has restricted error info whose error code equals `error` and whose restricted
//! description equals `message`. Any mismatch, or the absence of restricted error info, fails the current test.
//! @param error    Expected error code
//! @param message  Expected restricted description; compared with wcscmp, so it must not be null
inline void RequireRestrictedErrorInfo(HRESULT error, wchar_t const* message)
{
    wil::com_ptr_nothrow<IRestrictedErrorInfo> errorInfo;
    REQUIRE(GetRestrictedErrorInfo(&errorInfo) == S_OK); // S_FALSE means no restricted error info for the thread
    REQUIRE(errorInfo != nullptr);
    wil::unique_bstr description;
    wil::unique_bstr restrictedDescription;
    wil::unique_bstr capabilitySid;
    HRESULT errorCode = S_OK;
    REQUIRE_SUCCEEDED(errorInfo->GetErrorDetails(&description, &errorCode, &restrictedDescription, &capabilitySid));
    REQUIRE(errorCode == error);
    // Fail the test cleanly, rather than crashing inside wcscmp, if no restricted description was provided
    REQUIRE(restrictedDescription.get() != nullptr);
    REQUIRE(wcscmp(restrictedDescription.get(), message) == 0);
}
#endif

#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) && (_WIN32_WINNT >= _WIN32_WINNT_WIN7)

//! Owns a directory for the duration of a test.
//! The directory (including any missing parents) is created on construction and, if that succeeded, removed
//! recursively on destruction. Check the object in a boolean context to find out whether creation succeeded.
//! Movable but not copyable.
struct TestFolder
{
    //! Creates a folder at a freshly generated unique path (see CreateUniqueFolderPath)
    TestFolder()
    {
        m_valid = SUCCEEDED(CreateUniqueFolderPath(m_path)) && SUCCEEDED(wil::CreateDirectoryDeepNoThrow(m_path));
    }

    //! Creates a folder at `path`; the object stays invalid if `path` does not fit in MAX_PATH characters
    TestFolder(PCWSTR path)
    {
        const bool copied = SUCCEEDED(StringCchCopyW(m_path, ARRAYSIZE(m_path), path));
        m_valid = copied && SUCCEEDED(wil::CreateDirectoryDeepNoThrow(m_path));
    }

    TestFolder(const TestFolder&) = delete;
    TestFolder& operator=(const TestFolder&) = delete;

    //! Takes over cleanup of `other`'s folder; moving from an invalid object yields an invalid object
    TestFolder(TestFolder&& other) WI_NOEXCEPT
    {
        if (!other.m_valid)
        {
            return;
        }

        wcscpy_s(m_path, other.m_path);
        other.m_valid = false;
        m_valid = true;
    }

    ~TestFolder()
    {
        if (!m_valid)
        {
            return;
        }

        wil::RemoveDirectoryRecursiveNoThrow(m_path);
    }

    //! `true` if the folder was created and is still owned by this object
    operator bool() const
    {
        return m_valid;
    }

    operator PCWSTR() const
    {
        return m_path;
    }

    PCWSTR Path() const
    {
        return m_path;
    }

private:
    bool m_valid = false;
    wchar_t m_path[MAX_PATH] = L"";
};

//! Owns a file for the duration of a test.
//! The file is created (or truncated if it already exists) on construction. If its parent directory is missing,
//! that directory is created as well. On destruction the file is deleted and, if this object created the parent
//! directory, a best-effort attempt is made to remove that directory too. Check the object in a boolean context
//! to find out whether creation succeeded. Movable but not copyable.
struct TestFile
{
    //! Creates the file at `path`; the object stays invalid if `path` does not fit in MAX_PATH characters
    TestFile(PCWSTR path)
    {
        if (SUCCEEDED(StringCchCopyW(m_path, ARRAYSIZE(m_path), path)))
        {
            Create();
        }
    }

    //! Creates the file `fileName` inside the directory `dirPath`
    TestFile(PCWSTR dirPath, PCWSTR fileName)
    {
        if (SUCCEEDED(StringCchCopyW(m_path, ARRAYSIZE(m_path), dirPath)) && SUCCEEDED(PathCchAppend(m_path, ARRAYSIZE(m_path), fileName)))
        {
            Create();
        }
    }

    TestFile(const TestFile&) = delete;
    TestFile& operator=(const TestFile&) = delete;

    //! Takes over cleanup of `other`'s file (and directory, if any). Nothing is transferred from an invalid object.
    TestFile(TestFile&& other) WI_NOEXCEPT
    {
        if (other.m_valid)
        {
            m_valid = true;
            m_deleteDir = other.m_deleteDir;
            other.m_valid = other.m_deleteDir = false;
            wcscpy_s(m_path, other.m_path);
        }
    }

    ~TestFile()
    {
        // Best effort on all of these
        if (m_valid)
        {
            ::DeleteFileW(m_path);
        }
        if (m_deleteDir)
        {
            size_t parentLength = 0;
            if (wil::try_get_parent_path_range(m_path, &parentLength))
            {
                // Temporarily cut the path down to its parent, remembering the character that gets overwritten
                const wchar_t splitChar = m_path[parentLength];
                m_path[parentLength] = L'\0';
                ::RemoveDirectoryW(m_path);
                m_path[parentLength] = splitChar;
            }
        }
    }

    //! `true` if the file was created and is still owned by this object
    operator bool() const
    {
        return m_valid;
    }

    operator PCWSTR() const
    {
        return m_path;
    }

    PCWSTR Path() const
    {
        return m_path;
    }

private:
    //! Creates the file at m_path, creating the parent directory first when the initial attempt fails with
    //! ERROR_PATH_NOT_FOUND. Sets m_valid on success and m_deleteDir when the parent directory was created here.
    //! m_path always holds the full file path again when this returns, whether it succeeded or not.
    HRESULT Create()
    {
        WI_ASSERT(!m_valid && !m_deleteDir);
        wil::unique_hfile fileHandle(::CreateFileW(
            m_path, FILE_WRITE_ATTRIBUTES, FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, nullptr, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, nullptr));
        if (!fileHandle)
        {
            auto err = ::GetLastError();
            size_t parentLength = 0;
            if ((err == ERROR_PATH_NOT_FOUND) && wil::try_get_parent_path_range(m_path, &parentLength))
            {
                // Temporarily cut the path down to its parent. The overwritten character is saved and put back
                // (instead of assuming it was a backslash), and it is put back before any early return so that a
                // failure here does not leave m_path pointing at the directory.
                const wchar_t splitChar = m_path[parentLength];
                m_path[parentLength] = L'\0';
                const HRESULT createDirResult = wil::CreateDirectoryDeepNoThrow(m_path);
                m_path[parentLength] = splitChar;
                RETURN_IF_FAILED(createDirResult);
                m_deleteDir = true;

                fileHandle.reset(::CreateFileW(
                    m_path, FILE_WRITE_ATTRIBUTES, FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, nullptr, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, nullptr));
                RETURN_LAST_ERROR_IF(!fileHandle);
            }
            else
            {
                RETURN_WIN32(err);
            }
        }

        m_valid = true;
        return S_OK;
    }

    bool m_valid = false;     // The file exists and is deleted on destruction
    bool m_deleteDir = false; // The parent directory was created by Create() and is removed on destruction
    wchar_t m_path[MAX_PATH] = L"";
};

#endif
} // namespace witest

#pragma warning(pop)
